/**
 * Copyright 2024-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "model_build_ndk_options_util.h"
#include "securec.h"
#include "framework/infra/log/log.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_align.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_nncore.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_options.h"

#include "util/base_types.h"

namespace hiai {
namespace {
/**
 * Releases every non-null NN_TensorDesc handle stored in @p descs and resets each slot to nullptr.
 * Null slots are skipped, so calling this again on the same vector is harmless.
 */
void DestroyNNTensorDescs(std::vector<NN_TensorDesc*>& descs)
{
    for (NN_TensorDesc*& desc : descs) {
        if (desc == nullptr) {
            continue;
        }
        HIAI_NDK_NNTensorDesc_Destroy(&desc);
        desc = nullptr;
    }
}

std::vector<NN_TensorDesc*> ConvertToNNTensorDescs(const std::vector<NDTensorDesc>& descs)
{
    std::vector<NN_TensorDesc*> nnTensorDescs;
    for (size_t i = 0; i < descs.size(); i++) {
        if (descs[i].dims.empty()) {
            return nnTensorDescs;
        }
        nnTensorDescs.push_back(HIAI_NDK_NNTensorDesc_Create()); // 调用完HIAI_NDK_HiAIOptions_SetInputTensorShapes接口后释放
        HIAI_EXPECT_TRUE_R(nnTensorDescs[i] != nullptr, nnTensorDescs);

        Status ret = HIAI_NDK_NNTensorDesc_SetShape(nnTensorDescs[i], descs[i].dims.data(), descs[i].dims.size());
        HIAI_EXPECT_TRUE_R(ret == SUCCESS, nnTensorDescs);

        ret = HIAI_NDK_NNTensorDesc_SetDataType(nnTensorDescs[i],
            HIAIAlign::ConvertHIAIDataTypeToNN(static_cast<HIAI_DataType>(descs[i].dataType)));
        HIAI_EXPECT_TRUE_R(ret == SUCCESS, nnTensorDescs);

        ret = HIAI_NDK_NNTensorDesc_SetFormat(nnTensorDescs[i],
            HIAIAlign::ConvertHIAIFormatToNN(static_cast<HIAI_Format>(descs[i].format)));
        HIAI_EXPECT_TRUE_R(ret == SUCCESS, nnTensorDescs);
    }
    return nnTensorDescs;
}

/**
 * Writes the dynamic-shape settings from @p config into @p compilation, in this order:
 * enable/disable status, maximum cache count (config.maxCachedNum), then cache mode.
 * Each NDK result is checked with HIAI_EXPECT_TRUE (presumably logs and returns a failure status).
 *
 * @return SUCCESS when all three settings were accepted.
 */
Status SetDynamicShapeConfig(OH_NNCompilation* &compilation, const DynamicShapeConfig& config)
{
    const HiAI_DynamicShapeStatus shapeStatus =
        config.enable ? HIAI_DYNAMIC_SHAPE_ENABLED : HIAI_DYNAMIC_SHAPE_DISABLED;
    const HiAI_DynamicShapeCacheMode shapeCacheMode = static_cast<HiAI_DynamicShapeCacheMode>(config.cacheMode);

    Status status = HIAI_NDK_HiAIOptions_SetDynamicShapeStatus(compilation, shapeStatus);
    HIAI_EXPECT_TRUE(status == SUCCESS);

    status = HIAI_NDK_HiAIOptions_SetDynamicShapeMaxCache(compilation, config.maxCachedNum);
    HIAI_EXPECT_TRUE(status == SUCCESS);

    status = HIAI_NDK_HiAIOptions_SetDynamicShapeCacheMode(compilation, shapeCacheMode);
    HIAI_EXPECT_TRUE(status == SUCCESS);

    return SUCCESS;
}

std::vector<HiAI_ExecuteDevice> ConvertToHiAIExecuteDevice(const std::vector<ExecuteDevice>& executeDevices)
{
    std::vector<HiAI_ExecuteDevice> devices;
    HIAI_EXPECT_TRUE_R(!executeDevices.empty(), devices);

    for (size_t i = 0; i < executeDevices.size(); ++i) {
        devices.push_back(static_cast<HiAI_ExecuteDevice>(executeDevices[i]));
    }
    return devices;
}

/**
 * Applies the device-placement settings of @p config to @p compilation:
 *  - fallback mode (always);
 *  - model-level device order, when deviceConfigMode is MODEL_LEVEL and modelDeviceOrder is non-empty;
 *  - per-operator device order, when deviceConfigMode is OP_LEVEL and opDeviceOrder is non-empty;
 *  - device memory reuse plan (always).
 * Each NDK result is checked with HIAI_EXPECT_TRUE (presumably logs and returns a failure status).
 *
 * @return SUCCESS when every applicable setting was accepted.
 */
Status SetModelDeviceConfig(OH_NNCompilation* &compilation, const ModelDeviceConfig& config)
{
    Status retCode = HIAI_NDK_HiAIOptions_SetFallbackMode(compilation,
        static_cast<HiAI_FallbackMode>(config.fallBackMode));
    HIAI_EXPECT_TRUE(retCode == SUCCESS);

    if (config.deviceConfigMode == DeviceConfigMode::MODEL_LEVEL && !config.modelDeviceOrder.empty()) {
        std::vector<HiAI_ExecuteDevice> devices = ConvertToHiAIExecuteDevice(config.modelDeviceOrder);
        HIAI_EXPECT_TRUE(devices.size() == config.modelDeviceOrder.size());
        retCode = HIAI_NDK_HiAIOptions_SetModelDeviceOrder(compilation, devices.data(), devices.size());
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    } else if (config.deviceConfigMode == DeviceConfigMode::OP_LEVEL && !config.opDeviceOrder.empty()) {
        for (const auto &iter : config.opDeviceOrder) {
            // Bind by const reference to avoid copying the name and device list on every iteration.
            const std::string& opName = iter.first;
            const std::vector<ExecuteDevice>& executeDevices = iter.second;
            std::vector<HiAI_ExecuteDevice> devices = ConvertToHiAIExecuteDevice(executeDevices);
            // ConvertToHiAIExecuteDevice returns an empty vector for an empty input, so comparing sizes
            // alone (0 == 0) would let an empty per-op device list through; reject it explicitly, in line
            // with the non-empty guard on the model-level order above.
            HIAI_EXPECT_TRUE(!devices.empty() && devices.size() == executeDevices.size());
            retCode = HIAI_NDK_HiAIOptions_SetOperatorDeviceOrder(compilation,
                opName.c_str(), devices.data(), devices.size());
            HIAI_EXPECT_TRUE(retCode == SUCCESS);
        }
    }

    retCode = HIAI_NDK_HiAIOptions_SetDeviceMemoryReusePlan(compilation,
        static_cast<HiAI_DeviceMemoryReusePlan>(config.deviceMemoryReusePlan));
    HIAI_EXPECT_TRUE(retCode == SUCCESS);
    return SUCCESS;
}
}

/**
 * Translates ModelBuildOptions into HiAI NDK options on @p compilation.
 *
 * Each option group is applied only when its guard below holds (e.g. formatMode != USE_NCHW,
 * precisionMode != PRECISION_MODE_FP32, tuningStrategy != OFF, non-empty strings/containers).
 * NDK results are checked with HIAI_EXPECT_TRUE (presumably logs and returns a failure status).
 *
 * NOTE(review): when tuningConfig.cacheDir is set, tuningConfig.deviceMemoryReusePlan is applied after
 * modelDeviceConfig.deviceMemoryReusePlan and so presumably overrides it — confirm this precedence.
 *
 * @param compilation compilation handle that receives the options.
 * @param options     build options to apply.
 * @return SUCCESS when every requested option was applied; FAILURE when input tensor descriptors
 *         cannot be converted, when a tuning mode is requested without a cache directory, or when an
 *         NDK call is rejected.
 */
Status ModelBuildNDKOptionsUtil::AssignToNNCompilation(OH_NNCompilation* &compilation, const ModelBuildOptions& options)
{
    Status retCode = FAILURE;

    // Input shapes: the NN_TensorDesc handles are released right after SetInputTensorShapes.
    if (!options.inputTensorDescs.empty()) {
        std::vector<NN_TensorDesc*> inputDescs = ConvertToNNTensorDescs(options.inputTensorDescs);
        if (inputDescs.size() != options.inputTensorDescs.size()) {
            // Conversion failed part-way; release whatever was created.
            DestroyNNTensorDescs(inputDescs);
            return FAILURE;
        }
        retCode = HIAI_NDK_HiAIOptions_SetInputTensorShapes(compilation,
            inputDescs.data(), inputDescs.size());
        DestroyNNTensorDescs(inputDescs);
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    if (options.formatMode != FormatMode::USE_NCHW) {
        retCode = HIAI_NDK_HiAIOptions_SetFormatMode(compilation,
            static_cast<HiAI_FormatMode>(options.formatMode));
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    // Any non-FP32 precision mode reaches the NDK; float16 is enabled only for PRECISION_MODE_FP16.
    if (options.precisionMode != PrecisionMode::PRECISION_MODE_FP32) {
        const bool enableFloat16 = (options.precisionMode == PrecisionMode::PRECISION_MODE_FP16);
        retCode = HIAI_NDK_NNCompilation_EnableFloat16(compilation, enableFloat16);
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    if (options.dynamicShapeConfig.enable) {
        retCode = SetDynamicShapeConfig(compilation, options.dynamicShapeConfig);
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    if (options.modelDeviceConfig.fallBackMode != FallBackMode::ENABLE ||
        !options.modelDeviceConfig.modelDeviceOrder.empty() || !options.modelDeviceConfig.opDeviceOrder.empty()) {
        retCode = SetModelDeviceConfig(compilation, options.modelDeviceConfig);
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    if (options.tuningStrategy != TuningStrategy::OFF) {
        HiAI_TuningStrategy tuningStrategy = static_cast<HiAI_TuningStrategy>(options.tuningStrategy);
        retCode = HIAI_NDK_HiAIOptions_SetTuningStrategy(compilation, tuningStrategy);
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    if (!options.quantizeConfig.empty()) {
        // The NDK setter takes a non-const buffer; the data is passed through unchanged.
        retCode = HIAI_NDK_HiAIOptions_SetQuantConfig(compilation,
            const_cast<char*>(options.quantizeConfig.data()), options.quantizeConfig.size());
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    // A tuning mode without a cache directory is a configuration error.
    if (options.tuningConfig.tuningMode != TuningMode::UNSET && options.tuningConfig.cacheDir.empty()) {
        FMK_LOGE("tuning mode enable, but cacheDir is empty.");
        return FAILURE;
    }

    // Tuning settings are applied whenever a cache directory is given (including tuningMode == UNSET).
    if (!options.tuningConfig.cacheDir.empty()) {
        HiAI_TuningMode tuningMode = static_cast<HiAI_TuningMode>(options.tuningConfig.tuningMode);
        retCode = HIAI_NDK_HiAIOptions_SetTuningMode(compilation, tuningMode);
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
        retCode = HIAI_NDK_HiAIOptions_SetTuningCacheDir(compilation, options.tuningConfig.cacheDir.c_str());
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
        HiAI_DeviceMemoryReusePlan deviceMemoryReusePlan = static_cast<HiAI_DeviceMemoryReusePlan>(
            options.tuningConfig.deviceMemoryReusePlan);
        retCode = HIAI_NDK_HiAIOptions_SetDeviceMemoryReusePlan(compilation, deviceMemoryReusePlan);
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    if (!options.customOpPath.empty()) {
        retCode = HIAI_NDK_HiAIOptions_SetCustomOpPath(compilation, options.customOpPath.c_str());
        HIAI_EXPECT_TRUE(retCode == SUCCESS);
    }

    return SUCCESS;
}

}